package encryption_utils
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"strings"
)
// AES holds the symmetric key material that EncryptWithAES uses for AES-CBC
// encryption. Both fields are unexported; NewAES is the constructor that
// populates them.
type AES struct {
// key is the AES key; NewAES fills it with 32 random bytes.
key []byte
// iv is the CBC initialization vector; NewAES fills it with aes.BlockSize
// random bytes.
iv []byte
}
// NewAES returns an AES value whose 32-byte key and aes.BlockSize-byte IV are
// both filled from crypto/rand.
//
// It panics if the random source returns an error, because the signature has
// no error result to report it through.
func NewAES() *AES {
	a := &AES{
		key: make([]byte, 32),
		iv:  make([]byte, aes.BlockSize),
	}
	// Fill the key first and then the IV, each with its own read.
	for _, buf := range [][]byte{a.key, a.iv} {
		if _, err := rand.Read(buf); err != nil {
			panic(err)
		}
	}
	return a
}
// EncryptedResponse is the value returned by (*AES).EncryptWithAES. The struct
// tags give the field names used when it is marshalled as JSON.
type EncryptedResponse struct {
// EncryptedData is the AES-CBC ciphertext in base64 standard encoding.
EncryptedData string `json:"encryptedData"`
// AESProperties is the result of RSA.EncryptUsingPublicKey applied to the
// string "<base64 key>.<base64 IV>".
AESProperties string `json:"aesProperties"`
// AES is set to true by EncryptWithAES.
AES bool `json:"aes"`
}
// pkcs7Pad returns a new slice holding data followed by PKCS#7 padding:
// between 1 and blockSize bytes, each equal to the number of bytes added.
// Input whose length is already a multiple of blockSize therefore gains one
// full block of padding.
//
// The result never shares storage with data. A plain append(data, ...) would
// write the padding into data's backing array whenever the caller's slice has
// spare capacity, silently overwriting bytes the caller still owns.
//
// blockSize must be in the range 1..255: zero causes a division-by-zero panic
// and larger values do not fit in the single pad byte.
func pkcs7Pad(data []byte, blockSize int) []byte {
	padding := blockSize - len(data)%blockSize

	// Allocate the exact final size up front so neither append below can
	// touch the caller's backing array or trigger a reallocation.
	padded := make([]byte, 0, len(data)+padding)
	padded = append(padded, data...)
	return append(padded, bytes.Repeat([]byte{byte(padding)}, padding)...)
}
// pkcs7Unpad validates and strips PKCS#7 padding from data.
//
// The final byte gives the padding length n. The call fails when data is
// empty, when its length is not a multiple of blockSize, when n is zero or
// larger than blockSize, or when any of the last n bytes differs from n.
// On success the returned slice is a prefix of data (no copy is made).
func pkcs7Unpad(data []byte, blockSize int) ([]byte, error) {
	total := len(data)
	if total == 0 || total%blockSize != 0 {
		return nil, fmt.Errorf("invalid data length")
	}

	marker := data[total-1]
	n := int(marker)
	if n == 0 || n > blockSize {
		return nil, fmt.Errorf("invalid padding length")
	}

	// Every byte of the padding tail must repeat the marker value.
	cut := total - n
	for _, b := range data[cut:] {
		if b != marker {
			return nil, fmt.Errorf("invalid padding byte")
		}
	}
	return data[:cut], nil
}
// EncryptWithAES encrypts data with AES in CBC mode using the receiver's key
// and IV, after PKCS#7-padding it to aes.BlockSize.
//
// The key and IV are base64-encoded (standard encoding), joined as
// "<key>.<iv>", and passed to r.EncryptUsingPublicKey; the result becomes
// AESProperties. The ciphertext is returned base64-encoded in EncryptedData
// and the AES flag is set to true.
//
// An error is returned if the cipher cannot be created from the key or if
// r.EncryptUsingPublicKey fails.
//
// NOTE(review): the same stored IV is used on every call, and reusing a CBC IV
// under one key leaks plaintext-prefix equality. Confirm that each AES value
// encrypts only one message.
func (a *AES) EncryptWithAES(data string, r *RSA) (*EncryptedResponse, error) {
	blk, err := aes.NewCipher(a.key)
	if err != nil {
		return nil, err
	}

	plaintext := pkcs7Pad([]byte(data), aes.BlockSize)
	sealed := make([]byte, len(plaintext))
	cipher.NewCBCEncrypter(blk, a.iv).CryptBlocks(sealed, plaintext)

	enc := base64.StdEncoding
	bundle := fmt.Sprintf("%s.%s", enc.EncodeToString(a.key), enc.EncodeToString(a.iv))
	wrapped, err := r.EncryptUsingPublicKey(bundle)
	if err != nil {
		return nil, err
	}

	return &EncryptedResponse{
		EncryptedData: enc.EncodeToString(sealed),
		AESProperties: wrapped,
		AES:           true,
	}, nil
}
// DecryptWithAES reverses EncryptWithAES.
//
// encryptedAESProperties is passed to r.DecryptUsingPrivateKey and must yield
// "<base64 key>.<base64 IV>"; encryptedData is the base64 (standard encoding)
// AES-CBC ciphertext. The key and IV recovered from the properties are used
// for decryption; the receiver's own key and IV are not read.
//
// An error is returned when the RSA step fails, when the properties are not
// exactly two '.'-separated base64 fields, when the key is not a valid AES key
// size, when the IV is not one block long, when the ciphertext is not valid
// base64 or not a whole number of blocks, or when the PKCS#7 padding is
// invalid. Decode and padding errors are wrapped with %w so callers can
// inspect the cause with errors.Is / errors.As.
//
// NOTE(review): CBC is used without an authentication tag, and padding
// failures are reported distinctly. If callers expose this error to untrusted
// peers, consider an authenticated mode.
func (a *AES) DecryptWithAES(encryptedAESProperties, encryptedData string, r *RSA) (string, error) {
	decryptedProperties, err := r.DecryptUsingPrivateKey(encryptedAESProperties)
	if err != nil {
		return "", err
	}

	// Base64 standard encoding never emits '.', so a well-formed bundle
	// splits into exactly two fields: key, then IV.
	parts := strings.Split(decryptedProperties, ".")
	if len(parts) != 2 {
		return "", fmt.Errorf("invalid AES properties format")
	}
	iv, err := base64.StdEncoding.DecodeString(parts[1])
	if err != nil {
		return "", fmt.Errorf("failed to decode IV: %w", err)
	}
	key, err := base64.StdEncoding.DecodeString(parts[0])
	if err != nil {
		return "", fmt.Errorf("failed to decode key: %w", err)
	}

	block, err := aes.NewCipher(key)
	if err != nil {
		return "", err
	}
	blockSize := block.BlockSize()

	// cipher.NewCBCDecrypter panics on a wrong-sized IV. The IV comes from
	// the decrypted bundle, so reject it here with an error instead.
	if len(iv) != blockSize {
		return "", fmt.Errorf("invalid IV length: got %d bytes, want %d", len(iv), blockSize)
	}

	payload, err := base64.StdEncoding.DecodeString(encryptedData)
	if err != nil {
		return "", err
	}
	// CryptBlocks panics on input that is not a whole number of blocks.
	// Empty input is left for pkcs7Unpad to reject, as before.
	if len(payload)%blockSize != 0 {
		return "", fmt.Errorf("invalid ciphertext length: %d bytes is not a multiple of the block size %d", len(payload), blockSize)
	}

	// Decrypt in place: payload is a private buffer and CryptBlocks permits
	// dst and src to overlap entirely.
	cipher.NewCBCDecrypter(block, iv).CryptBlocks(payload, payload)

	plaintext, err := pkcs7Unpad(payload, aes.BlockSize)
	if err != nil {
		return "", fmt.Errorf("failed to remove padding: %w", err)
	}
	return string(plaintext), nil
}